#include "POF/Loss_generator_sample_grad.h"

// Weight of the count term in calc_approx_loss, which returns min_loss - w_cnt * cnt.
// Starts at 0.01 and can be changed at runtime through set_w_cnt.
double w_cnt{0.01};

void set_w_cnt(double new_w_cnt) {
    w_cnt = new_w_cnt;
}

/// Computes an approximate loss over every (label, action) pair of the current batch.
///
/// For each label pm in [0, n_label) and each action i in [0, nn_data->batch_size):
///     p_min(pm)  = min(postprob[pm][0][0], postprob[pm][1][0])
///     loss(pm,i) = -objective[pm][i]
///                  - beta * ( log(min(th / p_min(pm), 1)) - postprob[pm][0][1] / 2 )
/// The smallest loss(pm,i) is tracked, starting from user_parameter->min_loss, so that value
/// also acts as an upper bound on the result's min term.
///
/// Separately, an action is counted when the postprob[..][0] of its class
/// (nn_data->normalclass[pm][i]) is below p_min(pm) + 1e-9, i.e. the action selected the
/// class with the minimal probability (up to a small tolerance).
///
/// NOTE(review): `n_label` is a global declared outside this file (presumably in the
/// included header) — confirm. An older note here said "1 for safe / 0 for unsafe",
/// presumably describing the class indices in normalclass — confirm.
/// An earlier variant used the selected class's probability,
/// postprob[pm][normalclass[pm][i]][0], in place of p_min inside the log term.
///
/// @param nn_data        Supplies batch_size, objective[pm][i] and normalclass[pm][i].
/// @param indata         Supplies postprob[pm][class][k]; p_min reads only classes 0 and 1.
/// @param user_parameter Supplies min_loss (initial minimum), beta and th.
/// @return The minimum loss found, minus w_cnt times the number of counted actions.
double calc_approx_loss(NN_data* nn_data, Indata* indata, User_parameter* user_parameter) {
	double min_loss = user_parameter->min_loss;
	int cnt = 0;
	const int n_action = nn_data->batch_size;

	for (int pm = 0; pm < n_label; pm++) {
		// Everything except objective[pm][i] depends only on pm: compute it once per label
		// instead of once per action (this removes a log() call per action).
		const auto min_prob = std::min(indata->postprob[pm][0][0], indata->postprob[pm][1][0]);
		// The ratio is clamped to at most 1, so for a positive ratio the log term is <= 0.
		const auto penalty = user_parameter->beta
			* (log(std::min(user_parameter->th / min_prob, 1.0)) - indata->postprob[pm][0][1] / 2);

		for (int i = 0; i < n_action; i++) {
			const double loss = -nn_data->objective[pm][i] - penalty;
			// std::min(a, b) keeps `a` unless b < a, matching the original `if (loss < min_loss)`.
			min_loss = std::min(min_loss, loss);

			// Count actions whose class has (within 1e-9) the minimal probability.
			if (indata->postprob[pm][nn_data->normalclass[pm][i]][0] < min_prob + 1e-9) {
				cnt++;
			}
		}
	}
	return min_loss - w_cnt * cnt;
}